Toxic Comment Filter

BiLSTM model for multi-label classification
code
Deep Learning
Python, R
Author

Simone Brazzi

Published

August 12, 2024

1 Introduction

  • Build a model that can filter user comments based on how harmful their language is.
  • Preprocess the text by removing the tokens that do not contribute meaningfully at the semantic level.
  • Transform the text corpus into sequences.
  • Build a Deep Learning model with recurrent layers for a multilabel classification task.

At prediction time, the model must return a vector containing a 1 or a 0 for each label in the dataset (toxic, severe_toxic, obscene, threat, insult, identity_hate). A harmless comment is therefore classified by a vector of all zeros, [0,0,0,0,0,0], whereas a harmful comment has at least one 1 among the 6 labels.
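
As a minimal sketch of that output format, the snippet below turns per-label probabilities into the 0/1 vector. The helper name to_label_vector, the 0.5 threshold and the probability values are assumptions made for illustration, not values taken from the trained model.

```python
LABELS = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

def to_label_vector(probabilities, threshold=0.5):
    """Turn per-label probabilities into a 0/1 vector, one entry per label."""
    return [1 if p >= threshold else 0 for p in probabilities]

# Made-up probabilities, in the same order as LABELS.
clean_comment = to_label_vector([0.02, 0.00, 0.01, 0.00, 0.03, 0.00])
toxic_comment = to_label_vector([0.91, 0.12, 0.77, 0.04, 0.65, 0.08])

print(clean_comment)  # [0, 0, 0, 0, 0, 0]
print(toxic_comment)  # [1, 0, 1, 0, 1, 0]
```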

2 Setup

Leveraging Quarto and RStudio, I will set up an R and Python environment.

2.1 Import R libraries

Import the R libraries. They are used both for rendering the document and for data analysis, because I prefer ggplot2 over matplotlib. I will also use colorblind-safe palettes.

Code
library(tidyverse, verbose = FALSE)
library(tidymodels, verbose = FALSE)
library(reticulate)
library(ggplot2)
library(plotly)
library(RColorBrewer)
library(bslib)
library(Metrics)

reticulate::use_virtualenv("r-tf")

2.2 Import Python packages

Code
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import keras
import keras_nlp

from keras.backend import clear_session
from keras.models import Model, load_model
from keras.layers import TextVectorization, Input, Dense, Embedding, Dropout, GlobalAveragePooling1D, LSTM, Bidirectional, GlobalMaxPool1D, Flatten, Attention, LayerNormalization
from keras.metrics import Precision, Recall, AUC, SensitivityAtSpecificity, SpecificityAtSensitivity, F1Score


from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import multilabel_confusion_matrix, classification_report, ConfusionMatrixDisplay, precision_recall_curve, f1_score, recall_score, roc_auc_score

Create a Config class to store all the useful parameters for the model and for the project.

2.3 Class Config

The Config class gathers the basic configuration of the model in one place, which improves readability.

Code
class Config():
    def __init__(self):
        self.url = "https://s3.eu-west-3.amazonaws.com/profession.ai/datasets/Filter_Toxic_Comments_dataset.csv"
        self.max_tokens = 20000
        self.output_sequence_length = 911 # check the analysis done to establish this value
        self.embedding_dim = 128
        self.batch_size = 32
        self.epochs = 100
        self.temp_split = 0.3
        self.test_split = 0.5
        self.random_state = 42
        self.total_samples = 159571 # total train samples
        self.train_samples = 111699
        self.val_samples = 23936
        self.features = 'comment_text'
        self.labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
        self.new_labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate', "clean"]
        self.label_mapping = {label: i for i, label in enumerate(self.labels)}
        self.new_label_mapping = {label: i for i, label in enumerate(self.new_labels)}
        self.path = "/Users/simonebrazzi/R/blog/posts/toxic_comment_filter/history/f1score/"
        self.model =  self.path + "model_f1.keras"
        self.checkpoint = self.path + "checkpoint.lstm_model_f1.keras"
        self.history = self.path + "lstm_model_f1.xlsx"
        
        self.metrics = [
            Precision(name='precision'),
            Recall(name='recall'),
            AUC(name='auc', multi_label=True, num_labels=len(self.labels)),
            F1Score(name="f1", average="macro")
            
        ]
    def get_early_stopping(self):
        early_stopping = keras.callbacks.EarlyStopping(
            monitor="val_f1", # "val_recall",
            min_delta=0.2,
            patience=10,
            verbose=0,
            mode="max",
            restore_best_weights=True,
            start_from_epoch=3
        )
        return early_stopping

    def get_model_checkpoint(self, filepath):
        model_checkpoint = keras.callbacks.ModelCheckpoint(
            filepath=filepath,
            monitor="val_f1", # "val_recall",
            verbose=0,
            save_best_only=True,
            save_weights_only=False,
            mode="max",
            save_freq="epoch"
        )
        return model_checkpoint

    def find_optimal_threshold_cv(self, ytrue, yproba, metric, thresholds=np.arange(.05, .35, .05), n_splits=7):

      # instantiate KFold
      kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
      threshold_scores = []

      for threshold in thresholds:

        cv_scores = []
        for train_index, val_index in kf.split(ytrue):

          ytrue_val = ytrue[val_index]
          yproba_val = yproba[val_index]

          ypred_val = (yproba_val >= threshold).astype(int)
          score = metric(ytrue_val, ypred_val, average="macro")
          cv_scores.append(score)

        mean_score = np.mean(cv_scores)
        threshold_scores.append((threshold, mean_score))

      # Find the threshold with the highest mean score (once, after all thresholds are scored)
      best_threshold, best_score = max(threshold_scores, key=lambda x: x[1])
      return best_threshold, best_score

config = Config()

3 Data

The dataset is downloaded from the URL with tf.keras.utils.get_file. N.B. For reproducibility, I also saved a local copy of the dataset, because there were times when the link was not available.

Code
# df = pd.read_csv(config.path)
file = tf.keras.utils.get_file("Filter_Toxic_Comments_dataset.csv", config.url)
df = pd.read_csv(file)
Code
library(reticulate)

py$df %>%
  tibble() %>% 
  head(5)
Table 1: First 5 elements
# A tibble: 5 × 8
  comment_text            toxic severe_toxic obscene threat insult identity_hate
  <chr>                   <dbl>        <dbl>   <dbl>  <dbl>  <dbl>         <dbl>
1 "Explanation\nWhy the …     0            0       0      0      0             0
2 "D'aww! He matches thi…     0            0       0      0      0             0
3 "Hey man, I'm really n…     0            0       0      0      0             0
4 "\"\nMore\nI can't mak…     0            0       0      0      0             0
5 "You, sir, are my hero…     0            0       0      0      0             0
# ℹ 1 more variable: sum_injurious <dbl>

Let's create a clean variable for EDA purposes: I want to see visually how many observations are clean compared with the other labels.

Code
df.loc[df.sum_injurious == 0, "clean"] = 1
df.loc[df.sum_injurious != 0, "clean"] = 0

3.1 EDA

First a check on the dataset to find possible missing values and imbalances.

3.1.1 Frequency

Code
library(reticulate)
df_r <- py$df
new_labels_r <- py$config$new_labels

df_r_grouped <- df_r %>% 
  select(all_of(new_labels_r)) %>%
  pivot_longer(
    cols = all_of(new_labels_r),
    names_to = "label",
    values_to = "value"
  ) %>% 
  group_by(label) %>%
  summarise(count = sum(value)) %>% 
  mutate(freq = round(count / sum(count), 4))

df_r_grouped
Table 2: Absolute and relative labels frequency
# A tibble: 7 × 3
  label          count   freq
  <chr>          <dbl>  <dbl>
1 clean         143346 0.803 
2 identity_hate   1405 0.0079
3 insult          7877 0.0441
4 obscene         8449 0.0473
5 severe_toxic    1595 0.0089
6 threat           478 0.0027
7 toxic          15294 0.0857

3.1.2 Barchart

Code
library(reticulate)
barchart <- df_r_grouped %>%
  ggplot(aes(x = reorder(label, count), y = count, fill = label)) +
  geom_col() +
  labs(
    x = "Labels",
    y = "Count"
  ) +
  # sort bars in descending order
  scale_x_discrete(limits = df_r_grouped$label[order(df_r_grouped$count, decreasing = TRUE)]) +
  scale_fill_brewer(type = "seq", palette = "RdYlBu") +
  theme_minimal()
ggplotly(barchart)
Figure 1: Imbalance in the dataset with clean variable

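The dataset is clearly imbalanced, so it could be useful to compute class weights and pass them to the model during training.

Most of our texts are clean. In Table 2, clean accounts for 0.8033 of all label occurrences and the six toxic labels for the remaining 0.1967. Because a toxic comment can carry several labels, these are shares of label counts rather than of comments: relative to the 159571 comments in the dataset (total_samples in Config), the 143346 clean ones are about 0.8983.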
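
The frequencies in Table 2 are computed over the sum of all label counts. The short sketch below uses only the counts from Table 2 and total_samples from Config to compare that share with the share of comments; the variable names are chosen here for illustration.

```python
# Counts copied from Table 2.
label_counts = {
    "clean": 143346,
    "identity_hate": 1405,
    "insult": 7877,
    "obscene": 8449,
    "severe_toxic": 1595,
    "threat": 478,
    "toxic": 15294,
}
total_comments = 159571  # Config.total_samples

# Share of clean among all label occurrences (the freq column of Table 2).
label_share = label_counts["clean"] / sum(label_counts.values())
# Share of clean among all comments.
comment_share = label_counts["clean"] / total_comments

print(round(label_share, 4))    # 0.8033
print(round(comment_share, 4))  # 0.8983
```

The two numbers differ because a single toxic comment can add to several label counts at once.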

3.2 Sequence length definition

To convert the text into a useful input for a NN, it is necessary to use a TextVectorization layer. See Section 4.

One of its arguments is output_sequence_length: to choose it well, it is useful to analyze the length of our texts. To simulate what the layer will do, we remove the punctuation and the new lines from the comments.

3.2.1 Summary

Code
library(reticulate)
df_r %>% 
  mutate(
    comment_text_clean = comment_text %>%
      tolower() %>% 
      str_remove_all("[[:punct:]]") %>% 
      str_replace_all("\n", " "),
    text_length = comment_text_clean %>% str_count()
    ) %>% 
  pull(text_length) %>% 
  summary() %>% 
  as.list() %>% 
  as_tibble()
Table 3: Summary of text length
# A tibble: 1 × 6
   Min. `1st Qu.` Median  Mean `3rd Qu.`  Max.
  <dbl>     <dbl>  <dbl> <dbl>     <dbl> <dbl>
1     4        91    196  378.       419  5000

3.2.2 Boxplot

Code
library(reticulate)
boxplot <- df_r %>% 
  mutate(
    comment_text_clean = comment_text %>%
      tolower() %>% 
      str_remove_all("[[:punct:]]") %>% 
      str_replace_all("\n", " "),
    text_length = comment_text_clean %>% str_count()
    ) %>% 
  # pull(text_length) %>% 
  ggplot(aes(y = text_length)) +
  geom_boxplot() +
  theme_minimal()
ggplotly(boxplot)
Figure 2: Text length boxplot

3.2.3 Histogram

Code
library(reticulate)
df_ <- df_r %>% 
  mutate(
    comment_text_clean = comment_text %>%
      tolower() %>% 
      str_remove_all("[[:punct:]]") %>% 
      str_replace_all("\n", " "),
    text_length = comment_text_clean %>% str_count()
  )

Q1 <- quantile(df_$text_length, 0.25)
Q3 <- quantile(df_$text_length, 0.75)
IQR <- Q3 - Q1
upper_fence <- as.integer(Q3 + 1.5 * IQR)

histogram <- df_ %>% 
  ggplot(aes(x = text_length)) +
  geom_histogram(bins = 50) +
  geom_vline(aes(xintercept = upper_fence), color = "red", linetype = "dashed", linewidth = 1) +
  theme_minimal() +
  xlab("Text Length") +
  ylab("Frequency") +
  xlim(0, max(df_$text_length, upper_fence))
ggplotly(histogram)
Figure 3: Text length histogram with boxplot upper fence

Considering the analysis above, I think a good starting value for output_sequence_length is 911, the upper fence of the boxplot, shown as the dashed red vertical line in the last plot. Doing so, we truncate only the outliers, which are a small part of our dataset. Note that text_length counts characters while output_sequence_length counts tokens, so 911 is a generous bound.
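
For reference, the value 911 can be reproduced from the quartiles in Table 3 with the usual boxplot rule, upper fence = Q3 + 1.5 * IQR. This is just the arithmetic from the histogram code, rewritten as a small Python sketch.

```python
# Quartiles of text_length, copied from Table 3.
q1 = 91
q3 = 419

iqr = q3 - q1                       # interquartile range: 328
upper_fence = int(q3 + 1.5 * iqr)   # 419 + 492

print(upper_fence)  # 911
```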

3.3 Dataset

Now we can split the dataset into 3 sets: train, test and validation. Since sklearn has no single function that produces these 3 sets, we can call train_test_split twice:

  • split the data into a train set and a temporary set, with a 0.3 share for the temporary set.
  • split the temporary set into 2 equal-sized test and val sets.

Code
x = df[config.features].values
y = df[config.labels].values

xtrain, xtemp, ytrain, ytemp = train_test_split(
  x,
  y,
  test_size=config.temp_split, # .3
  random_state=config.random_state
  )
xtest, xval, ytest, yval = train_test_split(
  xtemp,
  ytemp,
  test_size=config.test_split, # .5
  random_state=config.random_state
  )

xtrain shape: py$xtrain.shape ytrain shape: py$ytrain.shape xtest shape: py$xtest.shape ytest shape: py$ytest.shape xval shape: py$xval.shape yval shape: py$yval.shape

The datasets are created with tf.data.Dataset, which builds a data input pipeline. The tf.data API makes it possible to handle large amounts of data, read from different data formats, and perform complex transformations. A tf.data.Dataset is an abstraction that represents a sequence of elements, in which each element consists of one or more components.

Here each dataset is created using from_tensor_slices, which builds a tf.data.Dataset from a tuple (features, labels). .batch lets us work in batches to improve performance, while .prefetch overlaps the preprocessing and model execution of a training step: while the model is executing training step s, the input pipeline is reading the data for step s+1. Check the documentation for further information.

Code
train_ds = (
    tf.data.Dataset
    .from_tensor_slices((xtrain, ytrain))
    .shuffle(xtrain.shape[0])
    .batch(config.batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

test_ds = (
    tf.data.Dataset
    .from_tensor_slices((xtest, ytest))
    .batch(config.batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

val_ds = (
    tf.data.Dataset
    .from_tensor_slices((xval, yval))
    .batch(config.batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE)
)
Code
print(
  f"train_ds cardinality: {train_ds.cardinality()}\n",
  f"val_ds cardinality: {val_ds.cardinality()}\n",
  f"test_ds cardinality: {test_ds.cardinality()}\n"
  )
train_ds cardinality: 3491
 val_ds cardinality: 748
 test_ds cardinality: 748
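
As a sanity check on these numbers, the sketch below reproduces the split sizes and batch counts with plain arithmetic. It assumes that the held-out share of each split is rounded up (which is consistent with train_samples and val_samples stored in Config) and that .batch keeps the last, smaller batch.

```python
import math

total_samples = 159571   # Config.total_samples
batch_size = 32          # Config.batch_size

# First split: 30% goes to the temporary set (rounded up), the rest to train.
temp_samples = math.ceil(total_samples * 0.3)
train_samples = total_samples - temp_samples

# Second split: the temporary set is halved into validation and test.
val_samples = math.ceil(temp_samples * 0.5)
test_samples = temp_samples - val_samples

# The final partial batch is kept, so the batch count is rounded up.
train_batches = math.ceil(train_samples / batch_size)
val_batches = math.ceil(val_samples / batch_size)
test_batches = math.ceil(test_samples / batch_size)

print(train_samples, val_samples, test_samples)  # 111699 23936 23936
print(train_batches, val_batches, test_batches)  # 3491 748 748
```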

Check the first element of the dataset to be sure that the preprocessing is done correctly.

Code
train_ds.as_numpy_iterator().next()
(array([b"THAT'S RIGHT - LISTEN UP ASSHOLE !!!",
       b"As for educating editors at List of Spooks episodes, perhaps you could do that after viewing the MoS. If you have any problems, I can help out but it's not an article that I frequent.",
       b"Thanks \n\nThat was the IP's second go, thanks for fixing it. I wonder if it's related to your post to our pages? We must have something in common.",
       b'Moonriddengirl IS A STUPID SLUT AND SHOULD SUCK ON MY DICK, STOP DELETING MY ARTICLES AND MIND HER OWN BUSINESS...\n\nFUCK YOU WHORE!',
       b'"\nOf course it\'s censorship. You think censorship is bad, and because you think Black Kite\'s edit was justified it can\'t be censorship? What, are they issuing admin bits to moral morons, nowadays?\nAnd Black Kite\'s another. I don\'t care if he is an admin, he\'s too big for his britches if he thinks it\'s ok to put a snarky ""m"" on any edit CENSORING another Wikipedian.   "',
       b'"\n\n Happy Holidays \n\n Happy Holidays, fellow Wikipedian. ) TALK "',
       b'Being corrected when you are absurdly wrong makes you feel like the centre of attention, does it?  201.215.187.159',
       b'"\n\n Regarding Greek Muslim Delusional Paranoia Syndrome \n\n""Considering the latest developments the embargo on the Turkish Cypriot side sounds totally uncalled for to the average Turk. In rejection in the union, the existence political entity in the Northern Cyprus is being ackowledged and the international recognition should follow.""\n\n- Hahahahahaha! You cannot be serious? Don\'t you understand how high levels decision in the EU Council of Ministers work? They don\'t work via QMV, they work via unanimity. Do you think such a decision will ever gain unanimity with the European Union? The Greek Muslim Republic of The Occupied Territories has, and never will be recognized. Turks are the eternal optimists, I\'ll give them that, I have never met a Turk who wasn\'t a Nationalist or an Islamist or a Militirist. The sad thing is Turkey\'s territorial integrity will eventually collapse in on itself, Turkey is an Eastern Country, yet it tried to have Eastern \'heart\', with the Western \'mind\', it went for double or nothing and now it will lose everything, you cannot rely on a far right military which despises EU demands to stop persecution of the Kurds to protect your Constitution Turkey, we are entering endgame, and I will be only too glad to gain revenge for what these filth did to my Uncle 31 years ago, I will bathe in your blood Turks, I will burn villages to the ground and pour salt over once fertile land. Turkey is a festering disease that should never have happened. \n\nMay Jesus Christ our Lord protect us as we smite our enemies and cast them out from our lands, may Jesus Christ give us the strength to commit ourselves to revenge, may God almighty give us the strength to become hateful creatures."',
       b"It is already linked to on the page. Anyone editing the article who had any sense would have thought of visiting that page first before making any edits. The facebook page was put there by the National Federation of Cypriots. It's own website is still under construction and the section on UK Cypriot history is incomplete.",
       b'"\n\n The ONLY worthwhile candidate. \n\nIt\'s true. Gravel is the only presidential candidate not spouting out the usual political BS (as well as Ron Paul). It only figures that he\'s receiving little to no attention. His accomplishments in the 60\'s and 70\'s with the War and the Pentagon Papers trumps the accomplishments of every other candidate, but that scumbag Giuliani is revered for his ""actions"" during 9/11. It\'s so depressing...ugh. This only justifies my cynicism even more. The majority of Americans are content with accepting the same garbage every election. If Clinton wins, that will mean for 24 years, the most powerful country in the world was controlled by two differen families. Disgusting.   "',
       b"This is Wiki. In Wiki, those rules don't apply. There is no debate over this. Also, I don't think it says that the woman most cover her head. It merely says that it's decent, but not as a requirement. Am I wrong? Even if I were wrong, this is still Wiki and we shouldn't discriminate based on someone's patriarchal beliefs.",
       b'So well referenced for fascist \n\nA so well referenced fascist needs to be part of the Fascist portal. Please do not remove with out s\\discussion and consensus and reason for not including it in such a well referenced article.',
       b'Black americans have a hive mind mentality and automatically switch political party preferences just like that.211.28.54.73',
       b'"*Please read the Discussion page for the Thaksin Shinawatra article where I give an explanation of why I wrote it that way.  The way you edited the article doesn\'t make any sense.  The data is cited to support the assertation that Thaksin, post-election, still retained high popularity, despite the PAD\'s attacks.  The original data used to support it had issues, and I replaced it with the poll data which you edited.  Admittedly, it is not perfect, but it is still better than the original.  \n\nPlease don\'t assume that since you think you are better educated, you are smarter and have better judgement than me, or the ""lemming-like"" majority of Thailand.  The system of power you advocate was overthrown in 1932.  And next time, please log your name.  In a free society, people should not be ashamed to have political views.  \n\n"',
       b"Check Yegor Gaidar's Pogib Imperii (Collapse of an Empire).",
       b'Dear Dr Suresh\n\nMy name is Alex Hankey, and I am currently working in Bangalore to develop a new institute in AYUSH at the Manipal Group, together with Professor Bhushan Patwardhan. \n\nI was very interested to read your Wikipedia article on Siddha. As a member of the editorial boards of 2 scientific journals, I would like to offer to edit the English for you, retaining the meaning and content. It would help the readability, and clarify certain things.\n\nPlease contact me on alexhankey@gmail.com if you would like me to do this.\n\nAlex Hankey\n\nP.S. I am a physicist by training, working on the biophysics of traditional medicine. I did my first degree at Cambridge, and my PhD at M.I.T. in Massacussetts, U.S.A. I have been working with traditional medicine since 1985, and have published major new theories of Tridosha, and Prana, the vital force(see the Journal of Alternative and Complementary Medicine, October 2001, February 2004, and several issues in 2005). I shall soon publish a theory of samaadi, and the soul body connection, based on my theory of vital force.',
       b'You are not following wikipedia rules, which go by published sources, with no source to contradict Winterberg (2004) then you have nothing but your own personal hot air.  Restore the link which you are censoring in blatant violation of the rules. 173.169.90.98',
       b"and you're not even going to link to a credible source,",
       b'Well, your statement does explain nothing but denial without your ground, so why explain why it is not your minor historical fact. I already provided the above citation.',
       b'"Welcome!\n\nHello, , and welcome to Wikipedia! Thank you for your contributions. I hope you like the place and decide to stay. Here are some pages that you might find helpful:\nThe five pillars of Wikipedia\nTutorial\nHow to edit a page\nHow to write a great article\nManual of Style\nI hope you enjoy editing here and being a Wikipedian! Please sign your messages on discussion pages using four tildes (~~~~); this will automatically insert your username and the date. If you need help, check out Wikipedia:Questions, ask me on , or ask your question on this page and then place {{helpme}} before the question. Again, welcome!    (Contact me) \n\nFriday\nThe reason that page hasn\'t been created is because the production has not been notable. Everything about the development of the film fits on the franchise page. See WP:NFF for more specifics, but unless there becomes in-deph discussion on the production (and I\'ve seen what\'s out there and there isn\'t enough) then we have to wait till the film is released.   (Contact me) "',
       b'greatest person ever to walk this earth',
       b"So what? Go away and stop harassing me. You're a bully and a stalker.-",
       b"There is absolutely no reason why the comment and link shouldn't be included. You can accuse me of and edit war but you have no hesitation in involving yourself in it, needlessly vandalising the article by deleting valid additions. You conduct is a complete disgrace to this community.",
       b"I'ma smack ya upside da head wit a shovel \n\nI'm takin ya down, boi.",
       b'"\n\nWell, I showed you that ""Kievan Rus"" is used in books more frequently than ""Ancient Rus"". How is this not ""convincimg""? Is the counting rigged? You may go check other encyclopedias and come back tell us what you find out.\n\nAs to redirects, you are right. No rules prohibit creation of redirects from sensible titles. ""Sensible"" is pretty much inclusive and a redirect from Ancient Rus is alredy created. The types of redirects that are not allowed would not be created in good faith anyway. Under whatever name you think this may be found in English or whatever name you think the reader  looking for this article might enter into a search string, redirect is always a good idea. What\'s not allowed are only redirects that are POV pushing through promotion of POV term or non-existing term. Like it\'s not OK to make Little Russian a redirect to Ukrainian and instead the redirect should be to Little Russia, an article about the etymology and usage of the term. You cannot also invent a term and start pushing it into usage by creating a redirect article under it. Most of the rest is not a problem. Ancient Rus redirect is certainly a good idea.  "',
       b'Question \n\nHow does one go about requesting/setting up a deletion sorting list (e.g., Wikipedia:WikiProject Deletion sorting/Sportspeople) for a new topic area?  I would like to set up a new deletion sort list for swimming.  Please let me know.',
       b'I am going to murder ZimZalaBim ST47 for being evil homosexual jews.',
       b'automatically produce your name and the date.',
       b'"\n\n Help with CSS Reference Manual article \n\nDear editor, this morning, I found a new article (CSS Reference Manual) initiated by another author (that I am not acquainted to). However, while the article was at that point (and this moment as well) clearly still \'work in progress\', another editor/commenter  had put an \'AfD\' on it. Because I feel a strong need for this new article to exist, I defended it strongly with valid arguments, even stating that I wanted to contribute to it myself. However, at the point where on-topic discussions based on argumentation seized to exist, the editor/commenter  started to attack me personally, by openly questioning my personal credibility.\nAt his talk-page, an earlier ban (half way of the page) is present based on personal attacks (although I am not fully sure that he was the one given the ban or that he was the one that gave the ban to someone else).\n\nNevertheless: because of this AfD, the article gets flamed now by other users that don\'t take a closer look at it, for no valid reason, just because it\'s currently labeled as \'AfD\'...\n\nMy question is therefore: could you help me out to let this article be enhanced properly to Wikipedia standards?\nBecause I strongly feel that there is a need for it, it deserves a decent chance, because I would not liketo see very hard work thrown down the drain because of ""a bully"" that was banned before for personal attacks, and because the article\'s initial author (again: not me) has done a good job in starting it."',
       b"I'm confused, too. It does look like Rubin crossed the line into TPm territory again, but I don't understand the sockpuppet allegation here. While the proxy editing may approach meatpuppetry, it's not sockpuppetry.",
       b'Thank you. I hope this edit warring will not escalate any further. Greetings.',
       b'Clearly, it does not help with ones mental organization and physical communcation just to do reading. On top of this it is necessary to do some writing.'],
      dtype=object), array([[1, 0, 1, 0, 1, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [1, 0, 1, 0, 1, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 1],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 1, 0],
       [0, 0, 0, 0, 0, 0],
       [1, 0, 0, 1, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [1, 1, 1, 1, 1, 1],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0]]))

We also check the shapes. We expect features of shape (batch, ) and targets of shape (batch, number of labels).

Code
print(
  f"text train shape: {train_ds.as_numpy_iterator().next()[0].shape}\n",
  f" text train type: {train_ds.as_numpy_iterator().next()[0].dtype}\n",
  f"label train shape: {train_ds.as_numpy_iterator().next()[1].shape}\n",
  f"label train type: {train_ds.as_numpy_iterator().next()[1].dtype}\n"
  )
text train shape: (32,)
  text train type: object
 label train shape: (32, 6)
 label train type: int64

4 Preprocessing

Of course preprocessing! Text is not the type of input a NN can handle. The TextVectorization layer is meant to handle natural language inputs. The processing of each example contains the following steps:

  1. Standardize each example (usually lowercasing + punctuation stripping).
  2. Split each example into substrings (usually words).
  3. Recombine substrings into tokens (usually ngrams).
  4. Index tokens (associate a unique int value with each token).
  5. Transform each example using this index, either into a vector of ints or a dense float vector.

For more details, see the TextVectorization documentation.
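
To make these steps concrete, here is a toy, pure-Python imitation of them. It is only a sketch and not how Keras implements the layer: the function names are hypothetical, punctuation stripping uses string.punctuation, and the convention that index 0 is padding and index 1 is the out-of-vocabulary token is assumed to mirror the "int" output mode.

```python
import string
from collections import Counter

def standardize(text):
    """Lowercase the text and strip ASCII punctuation."""
    return text.lower().translate(str.maketrans("", "", string.punctuation))

def build_vocabulary(texts, max_tokens):
    """Map the most frequent tokens to ids; 0 = padding, 1 = out of vocabulary."""
    counts = Counter(tok for t in texts for tok in standardize(t).split())
    kept = [tok for tok, _ in counts.most_common(max_tokens - 2)]
    return {tok: i + 2 for i, tok in enumerate(kept)}

def vectorize(text, vocabulary, output_sequence_length):
    """Index the tokens, then truncate or zero-pad to a fixed length."""
    ids = [vocabulary.get(tok, 1) for tok in standardize(text).split()]
    ids = ids[:output_sequence_length]
    return ids + [0] * (output_sequence_length - len(ids))

corpus = ["You, sir, are my hero!", "Hey man, you are right.", "You!"]
vocabulary = build_vocabulary(corpus, max_tokens=4)

padded = vectorize("You are MY hero", vocabulary, output_sequence_length=6)
truncated = vectorize("You are MY hero", vocabulary, output_sequence_length=2)

print(vocabulary)  # {'you': 2, 'are': 3}
print(padded)      # [2, 3, 1, 1, 0, 0]
print(truncated)   # [2, 3]
```

Short texts are padded with zeros and long texts are cut, which is what output_sequence_length controls.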

Code
text_vectorization = TextVectorization(
  max_tokens=config.max_tokens,
  standardize="lower_and_strip_punctuation",
  split="whitespace",
  output_mode="int",
  output_sequence_length=config.output_sequence_length,
  pad_to_max_tokens=True
  )

# prepare a dataset that only yields raw text inputs (no labels)
text_train_ds = train_ds.map(lambda x, y: x)
# adapt the text vectorization layer to the text data to index the dataset vocabulary
text_vectorization.adapt(text_train_ds)

This layer is set with:

  • max_tokens: 20000, a common choice for text classification. It is the maximum size of the vocabulary for this layer.
  • output_sequence_length: 911. See Figure 3 for the reason why. Only valid in "int" mode.
  • output_mode: "int", which outputs integer indices, one integer index per split string token. When output_mode == "int", 0 is reserved for masked locations; this reduces the vocab size to max_tokens - 2 instead of max_tokens - 1.
  • standardize: "lower_and_strip_punctuation".
  • split: on whitespace.

To preserve the original comments as text and also have a tf.data.Dataset in which the text is preprocessed by the TextVectorization layer, we can map the layer over the features of each dataset.

Code
processed_train_ds = train_ds.map(
    lambda x, y: (text_vectorization(x), y),
    num_parallel_calls=tf.data.experimental.AUTOTUNE
)
processed_val_ds = val_ds.map(
    lambda x, y: (text_vectorization(x), y),
    num_parallel_calls=tf.data.experimental.AUTOTUNE
)
processed_test_ds = test_ds.map(
    lambda x, y: (text_vectorization(x), y),
    num_parallel_calls=tf.data.experimental.AUTOTUNE
)

5 Model

5.1 Definition

Define the model using the Functional API.

Code
def get_deeper_lstm_model():
    clear_session()
    inputs = Input(shape=(None,), dtype=tf.int64, name="inputs")
    embedding = Embedding(
        input_dim=config.max_tokens,
        output_dim=config.embedding_dim,
        mask_zero=True,
        name="embedding"
    )(inputs)
    x = Bidirectional(LSTM(256, return_sequences=True, name="bilstm_1"))(embedding)
    x = Bidirectional(LSTM(128, return_sequences=True, name="bilstm_2"))(x)
    # Global average pooling
    x = GlobalAveragePooling1D()(x)
    # Add regularization
    x = Dropout(0.3)(x)
    x = Dense(64, activation='relu', kernel_regularizer=tf.keras.regularizers.l2(0.01))(x)
    x = LayerNormalization()(x)
    outputs = Dense(len(config.labels), activation='sigmoid', name="outputs")(x)
    model = Model(inputs, outputs)
    model.compile(optimizer='adam', loss="binary_crossentropy", metrics=config.metrics, steps_per_execution=32)
    
    return model

lstm_model = get_deeper_lstm_model()
lstm_model.summary()

5.2 Callbacks

Finally, the model has been trained using 2 callbacks:

  • Early Stopping, to avoid wasting Kaggle GPU time.
  • Model Checkpoint, to keep the best model found during training.

Code
# callbacks
my_es = config.get_early_stopping()
my_mc = config.get_model_checkpoint(filepath="/checkpoint.keras")
callbacks = [my_es, my_mc]
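
To get a feeling for how patience and min_delta interact, the sketch below simulates the stopping rule on a made-up list of validation scores. It is a simplified stand-in for the "max" mode of early stopping, not the Keras implementation; the function name stopping_epoch and the scores are assumptions for illustration.

```python
def stopping_epoch(scores, patience, min_delta=0.0, start_from_epoch=0):
    """Return the epoch index at which training would stop, or None.

    Simplified "max" mode: a score counts as an improvement only if it
    beats the best score so far by more than min_delta.
    """
    best = None
    wait = 0
    for epoch, score in enumerate(scores):
        if epoch < start_from_epoch:
            continue  # warm-up epochs are ignored
        if best is None or score - min_delta > best:
            best = score
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None

# Made-up validation F1 values that improve a little every epoch.
val_f1 = [0.10, 0.20, 0.25, 0.30, 0.31, 0.32, 0.33]

strict = stopping_epoch(val_f1, patience=2, min_delta=0.2)
lenient = stopping_epoch(val_f1, patience=2, min_delta=0.0)

print(strict)   # 2
print(lenient)  # None
```

With a large min_delta, steady but small gains do not count as improvements, so training can stop early even while the metric is still rising. This is worth keeping in mind when choosing min_delta for a metric bounded between 0 and 1.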

5.3 Final preparation before fit

Considering the dataset is imbalanced, we compute a class weight for each label to improve performance. Here the weight of a label is the share of positive examples for that label in the training set. These weights are passed to the model during training.

Code
lab = pd.DataFrame(columns=config.labels, data=ytrain)
r = lab.sum() / len(ytrain)
class_weight = dict(zip(range(len(config.labels)), r))
df_class_weight = pd.DataFrame.from_dict(
  data=class_weight,
  orient='index',
  columns=['class_weight']
  )
df_class_weight.index = config.labels
Code
library(reticulate)
py$df_class_weight
Table 4: Class weight
              class_weight
toxic          0.095900590
severe_toxic   0.009928468
obscene        0.052757858
threat         0.003061800
insult         0.049132042
identity_hate  0.008710911

It is also useful to define the steps per epoch for the train and validation datasets. This step is required to make sure the whole dataset is consumed during the fit; without it, I ran into exactly that problem.

Code
steps_per_epoch = config.train_samples // config.batch_size
validation_steps = config.val_samples // config.batch_size
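
With the values stored in Config, the integer division works out as follows. This is only the arithmetic of the two lines above, written as a standalone sketch.

```python
train_samples = 111699  # Config.train_samples
val_samples = 23936     # Config.val_samples
batch_size = 32         # Config.batch_size

steps_per_epoch = train_samples // batch_size
validation_steps = val_samples // batch_size

# Samples that do not fill a whole batch within one epoch's steps.
train_leftover = train_samples - steps_per_epoch * batch_size
val_leftover = val_samples - validation_steps * batch_size

print(steps_per_epoch, validation_steps)  # 3490 748
print(train_leftover, val_leftover)       # 19 0
```

The floor division gives 3490 training steps, one fewer than the 3491 batches reported by cardinality, because the last batch holds only 19 samples.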

5.4 Fit

The fit has been done on Kaggle to leverage the GPU. Some considerations about the model:

  • .repeat() makes the training dataset restart when it is exhausted, so it never runs out during fit.
  • epochs is set to 100.
  • validation_data has the same repeat.
  • callbacks are the ones defined before.
  • class_weight passes the frequency of each class to the training, because our dataset is imbalanced.
  • steps_per_epoch and validation_steps are needed because of the use of repeat.
Code
history = lstm_model.fit(
  processed_train_ds.repeat(),
  epochs=config.epochs,
  validation_data=processed_val_ds.repeat(),
  callbacks=callbacks,
  class_weight=class_weight,
  steps_per_epoch=steps_per_epoch,
  validation_steps=validation_steps
  )

Now we can import the model and the history trained on Kaggle.

Code
model = load_model(filepath=config.model)
history = pd.read_excel(config.history)

5.5 Evaluate

Code

validation = model.evaluate(
  processed_val_ds.repeat(),
  steps=validation_steps, # 748
  verbose=0
  )
Code
val_metrics <- tibble(
  metric = c("loss", "precision", "recall", "auc", "f1_score"),
  value = py$validation
  )
val_metrics
Table 5: Model validation metric
# A tibble: 5 × 2
  metric     value
  <chr>      <dbl>
1 loss      0.0542
2 precision 0.789 
3 recall    0.671 
4 auc       0.957 
5 f1_score  0.0293

5.6 Predict

For the prediction, the model does not need to repeat the dataset, because it has already been trained on all of the training data. It now only has to consume the new data once to make the predictions.

Code

predictions = model.predict(processed_test_ds, verbose=0)

5.7 Confusion Matrix

The best way to assess the performance of a multi label classifier is a confusion matrix. Sklearn has a specific function, multilabel_confusion_matrix, which builds one confusion matrix per label to handle the fact that a single observation can carry multiple labels.

5.7.1 Grid Search Cross Validation for best threshold

Grid Search CV is a technique for fine-tuning the hyperparameters of an ML model. It systematically searches through a set of hyperparameter values to find the combination which leads to the best model performance. In this case, I am using KFold Cross Validation, a resampling technique that splits the data into k consecutive folds. Each fold is used once as validation while the k - 1 remaining folds form the training set. See the documentation for more information.

The threshold is selected to optimize the recall. The decision was made because the cost of missing a True Positive is greater than the cost of a False Positive. In this case, missing an injurious comment is worse than classifying a clean one as bad.
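
The helper used later, config.find_optimal_threshold_cv, is not shown in this document. The sketch below is therefore only a guess at the general idea: the function name threshold_search_cv_sketch, the threshold grid, the number of folds and the synthetic data are all assumptions. It scores each candidate threshold on the validation part of every fold and keeps the threshold with the best mean score.

```python
import numpy as np
from sklearn.metrics import f1_score
from sklearn.model_selection import KFold


def threshold_search_cv_sketch(y_true, y_proba, metric, thresholds, n_splits=5):
    """Return the threshold with the best mean fold score, and that score.

    Hypothetical helper: `metric` must accept `average` and `zero_division`
    (for example sklearn's f1_score or recall_score).
    """
    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    best_threshold, best_score = None, -np.inf
    for threshold in thresholds:
        # Binarize the probabilities with the candidate threshold.
        y_pred = (y_proba >= threshold).astype(int)
        # Score the candidate on the held-out part of each fold.
        fold_scores = [
            metric(y_true[idx], y_pred[idx], average="macro", zero_division=0)
            for _, idx in kfold.split(y_true)
        ]
        mean_score = float(np.mean(fold_scores))
        if mean_score > best_score:
            best_threshold, best_score = float(threshold), mean_score
    return best_threshold, best_score


# Synthetic multi label data: 200 samples, 3 labels (illustration only).
rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=(200, 3))
# Noisy scores that are higher, on average, for positive labels.
noise = rng.normal(0, 0.2, size=y_true.shape)
y_proba = np.clip(0.25 + 0.5 * y_true + noise, 0, 1)

thresholds = np.arange(0.05, 1.0, 0.05)
best_threshold, best_score = threshold_search_cv_sketch(
    y_true, y_proba, f1_score, thresholds
)
print(f"Optimal threshold: {best_threshold}, best score: {best_score:.4f}")
```

Since the model is not retrained inside the loop, the folds here only average the score over different subsets of the predictions, which makes the chosen threshold less dependent on a single split.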

5.7.2 Confidence threshold and Precision-Recall trade off

Whilst the KFold GSCV technique is useful to test multiple hyperparameters, it is important to understand the problem we are facing. A multi label deep learning classifier outputs a vector of per-class probabilities. These need to be converted to a binary vector using a confidence threshold.

  • The higher the threshold, the fewer classes the model predicts, increasing model confidence [higher Precision] and increasing missed classes [lower Recall].
  • The lower the threshold, the more classes the model predicts, decreasing model confidence [lower Precision] and decreasing missed classes [higher Recall].
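
The trade-off can be seen on a tiny single-label example. The labels and probabilities below are invented for illustration; only the thresholding step matches what is done later in this document.

```python
import numpy as np
from sklearn.metrics import precision_score, recall_score

# Toy ground truth and predicted probabilities for one label.
y_true_toy = np.array([1, 1, 1, 0, 0, 0, 0, 0])
proba_toy = np.array([0.9, 0.6, 0.2, 0.7, 0.3, 0.1, 0.05, 0.02])

results = {}
for threshold in (0.15, 0.5, 0.8):
    # Convert probabilities to binary predictions.
    pred = (proba_toy >= threshold).astype(int)
    results[threshold] = (
        precision_score(y_true_toy, pred),
        recall_score(y_true_toy, pred),
    )
    precision, recall = results[threshold]
    print(f"threshold={threshold}: precision={precision:.2f}, recall={recall:.2f}")
```

On this toy data, moving the threshold from 0.15 to 0.8 raises precision from 0.60 to 1.00 while recall drops from 1.00 to 0.33.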

Threshold selection means we have to decide which metric to prioritize, based on the problem we are facing and the relative cost of misjudging. We can consider toxic comment filtering a problem similar to cancer diagnostics. It is better to predict cancer in people who do not have it [False Positive] and perform further analysis than to fail to predict cancer when the patient has the disease [False Negative].

I decided to train the model on the F1 score to have a model balanced in both precision and recall, and to leave it to the threshold selection to increase the recall performance.

Moreover, the model has been trained on the macro average F1 score, which is a single performance indicator obtained as the mean of the F1 scores of the individual classes, each of which is the harmonic mean of that class's Precision and Recall.

\[ F1\ macro\ avg = \frac{\sum_{i=1}^{n} F1_i}{n} \]

It is useful with imbalanced classes, because it weights each class equally. It is not influenced by the number of samples of each class. This is set both in config.metrics and in find_optimal_threshold_cv.
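
As a quick check of the formula, the sketch below computes the macro average by hand on invented two-label data and compares it with sklearn's f1_score. The arrays are assumptions for illustration, not model output.

```python
import numpy as np
from sklearn.metrics import f1_score

# Toy multi label data: 4 samples x 2 labels.
y_true_toy = np.array([[1, 1], [1, 0], [0, 0], [0, 0]])
y_pred_toy = np.array([[1, 1], [0, 1], [0, 1], [0, 0]])

# F1 of each label on its own (average=None returns one score per label).
per_label_f1 = f1_score(y_true_toy, y_pred_toy, average=None)

# Macro average: plain mean of the per-label F1 scores.
macro_by_hand = per_label_f1.mean()
macro_sklearn = f1_score(y_true_toy, y_pred_toy, average="macro")

print(per_label_f1, macro_by_hand, macro_sklearn)
```

Here the first label has F1 = 2/3 and the second F1 = 1/2, so the macro average is 7/12, regardless of how many positives each label has.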

5.7.2.1 f1_score

Code
ytrue = ytest.astype(int)
y_pred_proba = predictions
optimal_threshold_f1, best_score_f1 = config.find_optimal_threshold_cv(ytrue, y_pred_proba, f1_score)

print(f"Optimal threshold: {optimal_threshold_f1}")
Optimal threshold: 0.15000000000000002
Code
print(f"Best score: {best_score_f1}")
Best score: 0.4788653077945807
Code

# Use the optimal threshold to make predictions
final_predictions_f1 = (y_pred_proba >= optimal_threshold_f1).astype(int)

Optimal threshold f1 score: 0.15. Best score: 0.4788653.

5.7.2.2 recall_score

Code
ytrue = ytest.astype(int)
y_pred_proba = predictions
optimal_threshold_recall, best_score_recall = config.find_optimal_threshold_cv(ytrue, y_pred_proba, recall_score)

# Use the optimal threshold to make predictions
final_predictions_recall = (y_pred_proba >= optimal_threshold_recall).astype(int)

Optimal threshold recall: 0.05. Best score: 0.8095814.

5.7.2.3 roc_auc_score

Code
ytrue = ytest.astype(int)
y_pred_proba = predictions
optimal_threshold_roc, best_score_roc = config.find_optimal_threshold_cv(ytrue, y_pred_proba, roc_auc_score)

print(f"Optimal threshold: {optimal_threshold_roc}")
Optimal threshold: 0.05
Code
print(f"Best score: {best_score_roc}")
Best score: 0.8809499649742268
Code

# Use the optimal threshold to make predictions
final_predictions_roc = (y_pred_proba >= optimal_threshold_roc).astype(int)

Optimal threshold roc: 0.05. Best score: 0.88095.

5.7.3 Confusion Matrix Plot

Code
# convert predicted probabilities to binary predictions
ypred = predictions >=  optimal_threshold_recall # .05
ypred = ypred.astype(int)

# create a plot with 3 by 2 subplots
fig, axes = plt.subplots(3, 2, figsize=(15, 15))
axes = axes.flatten()
mcm = multilabel_confusion_matrix(ytrue, ypred)
# plot the confusion matrices for each label
for i, (cm, label) in enumerate(zip(mcm, config.labels)):
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot(ax=axes[i], colorbar=False)
    axes[i].set_title(f"Confusion matrix for label: {label}")
plt.tight_layout()
plt.show()
Figure 4: Multi Label Confusion matrix
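
To help read the plots, this small sketch shows what multilabel_confusion_matrix returns on invented two-label data: one 2 x 2 matrix per label, laid out as [[TN, FP], [FN, TP]].

```python
import numpy as np
from sklearn.metrics import multilabel_confusion_matrix

# Toy multi label data: 4 samples x 2 labels (illustration only).
y_true_toy = np.array([[1, 1], [1, 0], [0, 0], [0, 0]])
y_pred_toy = np.array([[1, 1], [0, 1], [0, 1], [0, 0]])

# Shape is (n_labels, 2, 2); each matrix treats its label as a binary problem.
mcm_toy = multilabel_confusion_matrix(y_true_toy, y_pred_toy)

for label, cm in zip(["label_a", "label_b"], mcm_toy):
    tn, fp, fn, tp = cm.ravel()
    print(f"{label}: TN={tn}, FP={fp}, FN={fn}, TP={tp}")
```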

5.8 Classification Report

Code

cr = classification_report(
  ytrue,
  ypred,
  target_names=config.labels,
  digits=4,
  output_dict=True
  )
df_cr = pd.DataFrame.from_dict(cr).reset_index()
Code
library(reticulate)
df_cr <- py$df_cr %>% dplyr::rename(names = index)
cols <- df_cr %>% colnames()
df_cr %>% 
  pivot_longer(
    cols = -names,
    names_to = "metrics",
    values_to = "values"
  ) %>% 
  pivot_wider(
    names_from = names,
    values_from = values
  )
Table 6: Classification report
# A tibble: 10 × 5
   metrics       precision recall `f1-score` support
   <chr>             <dbl>  <dbl>      <dbl>   <dbl>
 1 toxic            0.552  0.890      0.682     2262
 2 severe_toxic     0.236  0.917      0.375      240
 3 obscene          0.550  0.936      0.692     1263
 4 threat           0.0366 0.493      0.0681      69
 5 insult           0.471  0.915      0.622     1170
 6 identity_hate    0.116  0.720      0.200      207
 7 micro avg        0.416  0.896      0.569     5211
 8 macro avg        0.327  0.812      0.440     5211
 9 weighted avg     0.495  0.896      0.629     5211
10 samples avg      0.0502 0.0848     0.0597    5211

6 Conclusions

The BiLSTM model, optimized for high recall, performs well enough to make predictions for each label. Considering the low support for the threat label, the performance is not bad. See Table 2 and Figure 1: the threat label is only 0.27 % of the observations. The model has been optimized for recall because the cost of not identifying an injurious comment as such is higher than the cost of considering a clean comment as injurious.

Possible improvements could be to increase the number of observations, especially for the threat label. In general there are too many clean comments. This could be addressed by undersampling the clean comments, which I explicitly avoided in order to check the performance of the BiLSTM on an imbalanced dataset, leveraging the class weight method.